{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 安装第三方库\n",
    "!pip install torcheval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f9749c6f170>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from utils import Linear\n",
    "import pandas as pd\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# 固定随机数生成种子，使得计算结果可以复现\n",
    "torch.manual_seed(1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data(path):\n",
    "    \"\"\"\n",
    "    使用pandas读取数据\n",
    "    \"\"\"\n",
    "    data = pd.read_csv(path)\n",
    "    cols = [\"age\", \"education_num\", \"capital_gain\", \"capital_loss\", \"hours_per_week\", \"label\"]\n",
    "    return data[cols]\n",
    "\n",
    "\n",
    "if os.name == \"nt\":\n",
    "    data_path = \".\\\\data\\\\adult.data\"\n",
    "else:\n",
    "    data_path = \"./data/adult.data\"\n",
    "data = read_data(data_path)\n",
    "data[\"label_code\"] = pd.Categorical(data.label).codes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>education_num</th>\n",
       "      <th>capital_gain</th>\n",
       "      <th>capital_loss</th>\n",
       "      <th>hours_per_week</th>\n",
       "      <th>label</th>\n",
       "      <th>label_code</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>13</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>13</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32556</th>\n",
       "      <td>27</td>\n",
       "      <td>12</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>38</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32557</th>\n",
       "      <td>40</td>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32558</th>\n",
       "      <td>58</td>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32559</th>\n",
       "      <td>22</td>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>20</td>\n",
       "      <td>&lt;=50K</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32560</th>\n",
       "      <td>52</td>\n",
       "      <td>9</td>\n",
       "      <td>15024</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>&gt;50K</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>32561 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       age  education_num  capital_gain  capital_loss  hours_per_week   label  \\\n",
       "0       39             13          2174             0              40   <=50K   \n",
       "1       50             13             0             0              13   <=50K   \n",
       "2       38              9             0             0              40   <=50K   \n",
       "3       53              7             0             0              40   <=50K   \n",
       "4       28             13             0             0              40   <=50K   \n",
       "...    ...            ...           ...           ...             ...     ...   \n",
       "32556   27             12             0             0              38   <=50K   \n",
       "32557   40              9             0             0              40    >50K   \n",
       "32558   58              9             0             0              40   <=50K   \n",
       "32559   22              9             0             0              20   <=50K   \n",
       "32560   52              9         15024             0              40    >50K   \n",
       "\n",
       "       label_code  \n",
       "0               0  \n",
       "1               0  \n",
       "2               0  \n",
       "3               0  \n",
       "4               0  \n",
       "...           ...  \n",
       "32556           0  \n",
       "32557           1  \n",
       "32558           0  \n",
       "32559           0  \n",
       "32560           1  \n",
       "\n",
       "[32561 rows x 7 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 展示数据\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogitRegression:\n",
    "    \n",
    "    def __init__(self, neg, pos):\n",
    "        '''\n",
    "        定义逻辑回归模型的结构\n",
    "        参数\n",
    "        ----\n",
    "        neg ：Linear，负面的偏好，模型的形状为(k, 1)\n",
    "        pos ：Linear，正面的偏好，模型的形状为(k, 1)\n",
    "        '''\n",
    "        self.pos = pos\n",
    "        self.neg = neg\n",
    "        \n",
    "    def __call__(self, x):\n",
    "        '''\n",
    "        逻辑回归模型的向前传播\n",
    "        参数\n",
    "        ----\n",
    "        x ：torch.FloatTensor，形状为(n, k)，其中n表示批量数据的大小，k表示特征的个数\n",
    "        '''\n",
    "        self.out = torch.concat((self.neg(x), self.pos(x)), dim=1)\n",
    "        return self.out  # (n, 2)\n",
    "    \n",
    "    def parameters(self):\n",
    "        return self.neg.parameters() + self.pos.parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义模型\n",
    "pos = Linear(5, 1)\n",
    "neg = Linear(5, 1)\n",
    "model = LogitRegression(neg, pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 准备数据\n",
    "x = torch.tensor(data[['age', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']].values).float()\n",
    "x = F.normalize(x)                           # (32561, 5)\n",
    "y = torch.tensor(data['label_code']).long()  # (32561)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 1.2665, -1.7305]]), tensor([[0.9524, 0.0476]]), tensor([0]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 使用模型\n",
    "## 注意，模型输入数据的形状一定要是(n, 5)\n",
    "logits = model(x[[1]])            # (1, 2)\n",
    "probs = F.softmax(logits, dim=1)  # (1, 2)\n",
    "pred = torch.where(probs[:, 1] > 0.5, 1, 0)\n",
    "logits, probs, pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.0487), tensor(0.0487))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 计算模型在单点的损失\n",
    "loss = F.cross_entropy(logits, y[[1]])\n",
    "# cross_entropy的具体实现过程\n",
    "-probs[torch.arange(1), y[[1]]].log().mean(), loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 对于模型参数，需要记录它们的梯度（为反向传播做准备）\n",
    "for p in model.parameters():\n",
    "    p.requires_grad = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step      0/20000, loss:  0.6580\n",
      "step   2000/20000, loss:  0.5092\n",
      "step   4000/20000, loss:  0.5066\n",
      "step   6000/20000, loss:  0.5037\n",
      "step   8000/20000, loss:  0.4958\n",
      "step  10000/20000, loss:  0.5046\n",
      "step  12000/20000, loss:  0.5015\n",
      "step  14000/20000, loss:  0.4952\n",
      "step  16000/20000, loss:  0.5086\n",
      "step  18000/20000, loss:  0.5086\n"
     ]
    }
   ],
   "source": [
    "# 标准随机梯度下降法的超参数\n",
    "max_steps = 20000\n",
    "batch_size = 3000\n",
    "lossi = []\n",
    "\n",
    "for i in range(max_steps):\n",
    "    # 构造批次训练数据\n",
    "    ix = torch.randint(0, x.shape[0], (batch_size,))\n",
    "    xb = x[ix]\n",
    "    yb = y[ix]\n",
    "    # 向前传播\n",
    "    logits = model(xb)\n",
    "    loss = F.cross_entropy(logits, yb)\n",
    "    # 反向传播\n",
    "    loss.backward()\n",
    "    # 更新模型参数\n",
    "    ## 学习速率衰减\n",
    "    learning_rate = 0.1 if i < 10000 else 0.01\n",
    "    with torch.no_grad():\n",
    "        for p in model.parameters():\n",
    "            p -= learning_rate * p.grad\n",
    "            p.grad = None\n",
    "        \n",
    "    # 统计数据\n",
    "    if i % 2000 == 0:\n",
    "        print(f'step {i: 6d}/{max_steps}, loss: {loss.item(): .4f}')\n",
    "    lossi.append(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f975098d940>]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAAsTAAALEwEAmpwYAAAsqklEQVR4nO3dd3wUdfoH8M9DQu+9Q+gIIi0gAlJUuid24c6CDVHxTj31UNGznlhQf5yoZwErRVARBaQpVVpC74QQQkILCU1a2vf3x8xuZndndmeT3ewyfN6vV17ZnZ2deXZ25plvm1lRSoGIiJyrRKQDICKi8GKiJyJyOCZ6IiKHY6InInI4JnoiIoeLjXQA3mrUqKHi4uIiHQYR0UUlMTHxmFKqptlrUZfo4+LikJCQEOkwiIguKiKy3+o1Nt0QETkcEz0RkcMx0RMRORwTPRGRwzHRExE5HBM9EZHDMdETETmcoxL9/G2HcfT0+UiHQUQUVRyT6M/n5OGhrxNx52drIh0KEVFUcUyiz9d/QOVA1rkIR0JEFF0ck+iJiMic4xK9An8akYjIyDGJXiCRDoGIKCo5JtETEZE5xyV6xZYbIiIPjkn0wpYbIiJTjkn0LMkTEZlzTKJ3YcmeiMiT4xI9S/ZERJ4ck+hZkiciMueYRE9EROYcl+jZckNE5MlxiT47Nz/SIRARRRXHJPrUrLORDoGIKCo5JtFztA0RkTnnJHq2zhMRmXJMos9n0zwRkSnHJPpq5UtFOgQioqjkmERfghdMERGZckyi5++OEBGZc06iJyIiU7YSvYgMFJFdIpIkImMs5rldRLaLyDYRmWKYniciG/W/2aEKnIiI7IkNNIOIxACYCKAfgDQA60RktlJqu2GeFgCeBdBDKXVcRGoZFnFOKdUhtGETEZFddkr0XQEkKaWSlVLZAKYBGOo1z4MAJiqljgOAUupoaMMkIqLCspPo6wM4YHiepk8zagmgpYisFJHVIjLQ8FoZEUnQp99otgIRGanPk5CRkRFM/EREFEDAppsgltMCQB8ADQAsE5F2SqkTABorpdJFpCmA30Rki1Jqr/HNSqlPAHwCAPHx8bzElYgohOyU6NMBNDQ8b6BPM0oDMFsplaOU2gdgN7TED6VUuv4/GcASAB2LGDMREQXBTqJfB6CFiDQRkVIAhgHwHj0zC1ppHiJSA1pTTrKIVBWR0obpPQBsBxERFZuATTdKqVwRGQ1gPoAYAJOUUttE5BUACUqp2fpr/UVkO4A8AE8rpTJFpDuA/4lIPrSTyjjjaB0iIgo/W230Sqm5AOZ6TXvR8FgBeFL/M87zB4B2RQ8zMOGlsUREpnhlLBGRwzHRExE5HBM9EZHDMdETETmcYxJ9jQr84REiIjOOSfQiHHVDRGTGMYmeiIjMMdETETkcEz0RkcMx0RMRORwTPRGRwzHRExE5HBM9EZHDMdETETkcEz0RkcMx0RMRORwTPRGRwzHRExE5HBM9EZHDMdETETkcEz0RkcMx0RMRORwTPRGRwzHRExE5HBM9EZHDMdETETkcEz0RkcMx0RMRORwTPRGRwzHRExE5HBM9EZHDOTLRn7mQG+kQiIiihiMT/VMzNkU6BCKiqOHIRL8+9XikQyAiihqOTPRERFTAkYleqUhHQEQUPZyZ6CMdABFRFHFmomemJyJys5XoRWSgiOwSkSQRGWMxz+0isl1EtonIFMP0e0Rkj/53T6gC94+ZnojIJTbQDCISA2AigH4A0gCsE5HZSqnthnlaAHgWQA+l1HERqaVPrwbg3wDioWXfRP29HBZDRFRM7JTouwJIUkolK6WyAUwDMNRrngcBTHQlcKXUUX36AAALlVJZ+msLAQwMTejW2HRDRFTATqKvD+CA4XmaPs2oJYCWIrJSRFaLyMAg3gsRGSkiCSKSkJGRYT96IiIKKFSdsbEAWgDoA2A4gE9FpIrdNyulPlFKxSul4mvWrFnkYFigJyIqYCfRpwNoaHjeQJ9mlAZgtlIqRym1D8BuaInfzntDLutMdrhXQUR00bCT6NcBaCEiTUSkFIBhAGZ7zTMLWmkeIlIDWlNOMoD5APqLSFURqQqgvz6NiIiKScBRN0qpXBEZDS1BxwCYpJTaJiKvAEhQSs1GQULfDiAPwNNKqUwAEJFXoZ0sAOAVpVRWOD4IERGZC5joAUApNRfAXK9pLxoeKwBP6n/e750EYFLRwiQiosJy5JWxRERUgImeiMjhmOiJiByOiZ6IyOGY6ImIHI6JnojI4ZjoiYgcjomeiMjhHJvoT57NiXQIRERRwbGJ/tiZC5EOgYgoKjg20RMRkcaxiV4iHQARUZRwbqIXpnoiIsDBiZ6IiDRM9EREDueoRF+lXEn34zmbD0YwEiKi6OGoRJ+XV/Cz4CuTMiMYCRFR9HBUoq9SvmTgmYiILjGOSvR1K5d1P+agGyIijaMS/VP9W0U6BCKiqOOoRF861lEfh4goJBybGdl0Q0SkcW6i500QiIgAOCzRG0vxK5KORS4QIqIo4qhET0REvhyV6NlcQ0Tky1GJnoiIfDHRExE5nKMSPYdUEhH5clSiJyIiX0z0REQO56hEf1ndSpEOgYgo6jgq0ceUYCM9EZE3RyV6b6fO50Q6BCKiiHN0or//i3WRDoGIKOIcnejXpRyPdAi25OblQykVeEYiokJwdKK/GOTm5aP58/Pw+pwdkQ6FiByKiT7CcvO1kvzXq/dHOBIicipbiV5EBorILhFJEpExJq+PEJEMEdmo/z1geC3PMH12KIO3Y2ZiWnGv0m1L2knM2pBua1423BBRuARM9CISA2AigEEA2gAYLiJtTGadrpTqoP99Zph+zjD9htCEbd9TMzaZTldKIScvP6zr/ssHK/D49I32ZmamjxoPfZ2AZ2aa7zfRpv97S/H+ot2RDoOinJ0SfVcASUqpZKVUNoBpAIaGN6zwe/GnbWjx/LxIh8H780Sh+duO4LuEyNUEg7H7yJ94f9GeSIdBUc5Ooq8P4IDheZo+zdstIrJZRGaKSEPD9DIikiAiq0XkRrMViMhIfZ6EjIwM28EXRbS1iSsW6YlCbtXeTGTnhrfmHqzcvPxiv8YnVJ2xPwOIU0pdAWAhgC8NrzVWSsUD+CuA90WkmfeblVKfKKXilVLxNWvWDFFI5pIz/sTxM9lhXUcw+GMpROGxNf0khn+6Gm/Mi64Rbc/M3IwrXlpQrOu0k+jTARhL6A30aW5KqUyl1AX96WcAOhteS9f/JwNYAqBjEeItlANZZ92Prxm/FAPeX1bcIQTEYfTkBDsPn8LQiStx5kJupENBpl6gSzr6Z4Qj8fSDzQEaoWQn0a8D0EJEmohIKQDDAHiMnhGRuoanNwDYoU+vKiKl9cc1APQAsD0UgQfDmOgB4OjpCxZzhk/acc8YElKycPp8jruNPjdfIW7MHKxPLb6LvE6czcaRU+eLbX0EzNqQjoSUrLAs+9T5HKSfOBeWZdv1xtyd2HTgBNaG6TM6SXFeJBkw0SulcgGMBjAfWgL/Tim1TUReERHXKJq/i8g2EdkE4O8ARujTLwOQoE//HcA4pVSxJ/rkY2eKe5UeJv6ehJ5v/o7le7T+h5PncnDrx6vwyLfrfeZduqt4+igAoOvri3HlfxYX2/oIeHz6Rtz68aqwLPv6CSvQY9xvYVn2xcgqkZ7PyQtbASc/XyHX5mi+fIVi6z+w1UavlJqrlGqplGqmlHpdn/aiUmq2/vhZpVRbpVR7pVRfpdROffofSql2+vR2SqnPw/dRNHUqlfGZNnbWVqQdP4u3ft1p+b5NB064EzGgJeNgZOfm49kfNuPoad8d6O35uwAAOw+dds8LADsOnQpqHaGWHebhpVS8Ur1qrpeCxP1ZePnnbX7nEa+hbSMmrw1bAefRKevR3OZovud/3IKWY4tn5J/jrozNtziLj56yAR8u2Wv5vqETV+Kuz9ciP19hxZ5jaP/yAizZddT2ehdsP4ypaw9g9JQNtt+jlG/bvFn0qZln0e0/i3EwTNXysbO2FPq9v249hAXbDocwGiL7bvloFSavTEHcmDk4edazcGbVMLI62X6zUm5eflCl7nlb7R8L09YdCDxTiDgw0ZtPv2Dzy+r82kLc+fkaAEDift/2cqUUnv9xCzYdOOE1Xfu/dp/1TuQaQul37LzJiWrqulQcPnUePwboxEk7fhb/t2hP0G1/36xOBQA8M3MTpqxJtfWerDPZOHMhF6O+WY+RXycGtT4zv+88ihV7jhV5OYXx08Z0/LA+dOPm8/MVHp2yHnsz7HUCXsjNw8zEtEvmxnbnsvOwfE8G8vMVEveHri1/15HTptOLMq7tlo/+KLZSdzg5LtFbHSxWzSRxY+Z4XAV53KtUkJzxp0dzzOkLufh2TSru/GyNx3yFufDJLNJgD/U/L+Ti8EktvlHfJOK9RbuxNyNwn4RZbeW7hDQ896N16X7TgRNIOqodTJ1eXYjr3l0aZLTW7v1infsEGyylFJbvySh0ovzHtI148rvQXQn79oJdmLP5EK4db2/7vL9oD56asQnzQ1Qz+nBJUkiWUxT3Tl6Hj5ea16DHztqKuz5fi+d+3IJbPlqFZbuLr18qWJvSTrofr07OxDt6M+zFxnGJ/qpm1YN+j7+rIK8ZvxRdXy9ozwtFoct1TsgzqX54Lz9xfxYW7zhiuazrJyxHtze0+M5l5+nLCBzkiMnB36t/6MSVuO7dgqGph0769kf8uvUw9md6nmhW7c3Eqr2ZOHr6fJGT2fR1qT7Ln7r2AO76fC1+2niwSMsOlc1pJ4KaP0MfBXbqXGiGJH672rpW9vOmg9h28KTl6wBw9PR5vDR7m+1OxQNZZ92fwWjcvII+sY+X7sXC7dp+nKTXdFwjzA6d9GySPJ+T5y68BCM3Px+fLkvG+RztOJiRcMA93cya5Ew8M3OTreMlL19h2Cer8cHv1idRs+M5Wjgu0TetWSFky9ro1TwDAO8t1O8r4lWCN37J62wOLTPr8PW+QvaWj1Zh9xHrJoCUzIIOOGOnU8qxM/h6VYqtOEJp1DeJ6PfeMpzLzsPHS/cibswcDP90NYZ/uhq3fPQHHvo6EamZ1p2Gz/5gXqM4fiYbSin86/stuOnDPzxeO6APXQ3X0MJtB08GdSW1RV6xZFYZTD9xzueE5u3YnwXJddKKfT6vn7mQi31eI84em7oBQyas8LvcF2ZtxRd/pOB3fQRYyrEz+Me0DcjO1X43YddhzyaSq9/6HV1eX4QHvkzAUovS+bh5O/HgVwmmr7nybH6+glIKj3673l14Ccb3iel4fe4OTNST8dwtWqFiZVKmTw0cAO74ZDW+S0izVXgz3pzwtMVVrcYTfNrxs7ZrmPn60OrJK32/w1BxXKLv1aJGyJa13KTN+Is/UtyPs3Pz3aUjY6K36pxcmZSJI6fOeyTkUCUn43Cxfu8tQ593luCFn7bZ7lD+deuhkMQBaNvlzV93epToAOBAlvZZX5+7HfO2HPIpyQHA1LW+pdElu46i46sL3d9HViGvbM788wKGTFjuc01DIEMmrMALs7YCAI6eOu8umVpJO1G40S/Gk3yPcb+h99tLsMhiXcf+vID41xa5n7/yS8Go5fQT53D6fA7u+nwN+r6zJOg4vEumfd5Zgp82HsT9X67D/5YlY8D7y/DQ175Je5Gfmqc/nyxPRn6+QtPn5uL+LxOweKfnPvvKz9tx9Vu/IS9fYWZiGvL0E8KczZ77rOsirdPnfWtGK5J8j2XXYWj8tEopvLtgl8/wy38abo746fLACbnnm7/b7mzN0UsGb8y1HhVYVI5L9M1CWKIP5LU52zFkwgpMX5fq0UbvPZzLZenuDPzlvyvwm2FHTj/umez8FQJEtGpu3Jg5PiWnpKN/mpYMdxwqKH2dPJuDD5ckmZY0Rn3jO6bfJSEly1ZH6XcJBTu2v2Scr4CHv12PWz+yN57c1Sm+IfWErfmt/LghHdsOnsKkFSm25v90WbLPtGGfrMaDXyWYNmucy87Dgayz7hMaAHdyyvdTrXcnHJNZHjApBefm5Qcs7X+fmIb1hdxerji8R5ss33PMffKev61wSf18Tp7PfpqccQa/bNGStvHYOHk2B//+aSsmrdyHA1nn8NWqFDw1YxOmrE3FrI3peHSK5z7rOlGeOpdj614yrjimrzuA33Zqn2dT2klM+C0pJMMvE/cfx/+W7g14ZW5x9MHHhn8Vxatq+VJhWW7WmWyUKel5XvxqlVad/9f3W/DeHe3d0z9ZlmxZUj96+oLHrZO9zwkK2nDK2BhBvSplfd7vSnbT1qaid8uC+wJNXZtqObT04Ilz6D7uN1QsE4vT53PRrn5ly88JaG3Gufn5OJB1Dj9vOujTbOE9jA0Adh85jWdmbnY/93eguUrE6SfO4ZvV+3Fnt8Yer6ccO4MRk9fiu1FXoVbFgusiAt347deth9GsZgUMvLyONr9SeG/RHgzv2hB1K5f1OOkBwJu/7sTVza1rgK/P3YEbO3rev2+/11j1fcfOYHPaCQztUB893/zNfdm9S7Pn5gLQDvon+rUwXU+g+x1N/D0Jj/ZtDkBrolm4/QhWJWf6fY+3jQdO4JTNa0OW66XfN3/diSFX1LWc7/DJ86hT2fe6FX9av/Cr+7Gx2fHxab7Dkt+avxPfGkaBHTmlNVW9PHsbqpkc566RdT9sSMfsTYH7a0QEUMo9ACFl3BDLY8iDUpi1IR2PT9+IrS8PQLmSMRi/cBfaN6jiMdvq5EzMTEzDG/N2ImXcEMvFuQttYbztleMSfbh0enWhx3Pv78T7YPWuVlo5YZI0e739OwDgg7963hZowuI9GH9bB9Pl/GKxPpGCk4OrSuvdxuqty+uL/L7e/hXfGzJdyPEs/S2xeYXv2Flb0aFhFY9pk1buQ0rmWczdfAgjejRxb9lzeieblS3pJzHqm0T3QTV66gbM2XwIK5OO4fuHu+N7fQil6+T60ZK9+MjPtRUA0Ef/LgBtu7iaNRS0tlpX08jQDvV9krzR1LWpmLvF/z6hoDUvXeM1Wuft+btwTetaqF6hlEcTTTBunLjS7+sLtx/Rrtju3MBdkg90AVa3Nxb7TWCBGGsMZhWe3DzPia5RPLn5yvQ2JsZ9LtdkgXFj5mBwuzru595NVFPWpOKyuhULYrKohU34raBDNv34OWSdycbE3333ozRDbf30+Rx8uyYVI69u6jOfq/M4Ozcfb8zbgWcHXWa63qJwXNNNpNj+gREvPtVPw77lffHV+Zx8e2PxDTL/vIBnf9jsMe21KPt9Wu8E6V2ociX4/y0taEqx0+HtOtleyM3zaK4KpuB0Jrvg5GIcWfLE9I1oZ7gDoZ0fsfHufN+QehxT16Z6NN3cPWmtaSf91vSTQXfyulhd8LNqbybixszBhtTjePCrBNMf6cn80/99oeLGzPH7+l/+u8I9JDcauDpozYydtcWj2TVQwQLQaplWo3qMXvtlB8bN2xnwR2I+t9H+Xxgs0RfSKZMOn1AI1DzhSih2b29sp+Mo0rz7DFyjnX7RS/Rmn+G2j1dh68sDUKF0rGkbp/fYbFcTCgB8tmIfPjMZpZKaeRaNqpezFbN3Deq8jaTgzTV6aHhX7eaw2w+dxLaD5td7PD1zs+l0O8zaiJ/7cQsqlon1iAOAz4VjfxbxLpRb0k96DMkNRiR+o8F4VNkpTCVnnLG1jU5f0E7extqAC9voL0EpAW7A9sR0rdQ1Z8shVC3CrQtCze6QUju2pGsjmRL2H/dbYvx29X68MW8nmtfy7YC/e9Ja92OBWF4xbfTTxnQ8dq15O3qoGU9uU9dqndjf+Bn/HqyXfi5o4hk8YbnP61PWpGJ410Y+070vHOv99pKQxRSs4r53j/cgimOnA4/uMrsxoRl/NYlkm1dQFwWbbqJMMKMZQpkYiqqwbcdA4Uurb+gjQLxLrF8V8vqBovysY7A1J6vL9YtTNF+RCgR3T5pQyMtXeNNw48Nehv6ZcDKW8sP106Is0VPEmV1VWRQv/uR5N0NXDSGQdxbsRm2Tu5/aMWFxcL/bOvB931J2cYv0veuj0R97gxvNdLFgiZ7IoCht4URFlZMXngZ7JnoiIodjoicicjgmeiIih2OiJyJyOCZ6IiKHY6InInI4JnoiIodjoicicjgmeiIih2OiJyJyOCZ6IiKHY6InInI4Ryb6quVKRjoEIqKo4chE/+MjPSIdAhFR1HBkoo+rUR4f/q1TpMMgIooKjkz0AHDdZbUjHQIRUVRwbKIvFevYj0ZEFBRmQyIih3N0om9Rq0KkQyAiijhHJ/ouTapFOgQioohzdKJXKjw/tEtEdDFxeKKPdARERJHn6ETfu2XNSIdARBRxthK9iAwUkV0ikiQiY0xeHyEiGSKyUf97wPDaPSKyR/+7J5TBBzKoXd3iXB0RUVQKmOhFJAbARACDALQBMFxE2pjMOl0p1UH/+0x/bzUA/wZwJYCuAP4tIlVDFr0NT/ZrWZyrIyKKOnZK9F0BJCmlkpVS2QCmARhqc/kDACxUSmUppY4DWAhgYOFCLZzL61cqztUREUUdO4m+PoADhudp+jRvt4jIZhGZKSINg3mviIwUkQQRScjIyLAZuj3xcRxiSUSXtlB1xv4MIE4pdQW0UvuXwbxZKfWJUipeKRVfs2ZoO1ArlSmJyfd2CekyiYguJnYSfTqAhobnDfRpbkqpTKXUBf3pZwA6230vERGFl51Evw5ACxFpIiKlAAwDMNs4g4gYh7fcAGCH/ng+gP4iUlXvhO2vTytWVzWtXtyrJCKKGrGBZlBK5YrIaGgJOgbAJKXUNhF5BUCCUmo2gL+LyA0AcgFkARihvzdLRF6FdrIAgFeUUllh+Bx+lSkZU9yrJCKKGrba6JVSc5VSLZVSzZRSr+vTXtSTPJRSzyql2iql2iul+iqldhreO0kp1Vz/mxyejxGcNnUrYevLAyIdhl+j+zaPdAhE5MeQKy6e63QcfWWsmUf6NMOMUVehQmn/lZmxQy5DL5tX1v53eMdQhAYAaFC1LACgZ4saIVsmWatfpWykQ6CLVMeGVSIdgm2XTKJ/4fo2+OuVjfDMwNYo75Xkt5mU7h+4uqmt5b564+UYUogrcKuXL4V1z1/nMW103+aoF0Ti6d7s4ul7qFQmYCthoYzoHlek939RxBFZPZsX7wm5dJT8oM6jfZuFdflXXwQFnbz8i+dmWtGx1xSD+3s2wX9uaucxbfboHnhusG/i/+GR7gAA8VrG+Nva+yy3pZ973o/oHmfZEXxjx/qoWbG0x7QWtc2X1dXidstf338lalfSllG5bEnLOFz+O7wj+rSKzP1/qpYvFZbltqlbtAvixPtLBvD0gFa23399GKrv7epXxl3dGmPXawNRzWu7dSnidSHDuzYMPFOYNK1R3va8Jcy+mDBpVtN+XEbNL6Lfu7hkEr2ZKxpUwcheWsnEWIWvWUFLnsbzdcq4IbilcwOfZYifHfKlG9rik7s7m7422KQWICJoXaciAKBquVK4Pb4BRnSPw7Au5gdnTImCdc8YdZX78aDL61jGFMzhUz1MydnlsWt8+yGaBJEMAEDBt1QVW0IwoG3hfzP4UYv+kfYmVfUbO2rX/zWuXq7Q6/P2/cPd8eqNl6N0bAzKWgwkiG9c1eck4M+OVwZi3M3t8FAv85L4u7f7FmKMunqdYDo0DP5OJte1sf+dFFeeLx1bAjd38j2u7bj2Ivpd6ks60RvNfPgqvHXLFfjs7ng0rGb/oLXaISeP6KK/7jlDy9oVMP/xXujc2PdAubJJNYwd0gYzRl2FVnUq4q1b2+OlG9ripo5mFyJr7ojXTgLGE9UHf+3kJ96CeFLGDbGc7+kBrfDzYz3x8Z3WyyqM9g2r4M1b2uHDv3XCP/v7lpx/Gt2jUMsddHkdxOvb9JsHrsT42zsUJUxbFjzRC2VKxuDd29tj6oPdLOcL9qTj7/eOOzWqAgD4z83tfGqo3p4d1Nr9uGypGAzr2ghxFifSQMmuZ4sauKF9PQDAc4Nbo1+b2pg20vozm7Gbu81+Ga5htbLoa1IbDdTHsuGFfn6nj72+DR7u7Xny864lLnmqj9912PHi9W0wspdnc3CXuOK77RcTva5u5bK4vUtDj1KH2Y55betaAIDHr2uhv6+M6fKs2trLlIxBK73UbpQybghqVyqDUrElfKrn/moNT/Rrid2vDfJpfrLSsKr1gTHbkGRv7lQf9aqUxcDL6wZsHoktEfgQNs5xR5dGPjWaFf/qi4VP9EKlMgVNUD2aVzc9uM1ULBOLe3s0AaAligqlYz3aeb9/uLvpshpULaf/958wvri3i8cPHPRsXgMta2vf482dGqBelbJY/M/e6KgnYqP/3RXvMy2uejn0MynhPje4tcdz76/+H9e1xMIneqFl7YoYeHkd7HtjsGXMD/VuhmcGtsJMQ22vsJQCuunNkFc11bZrB72GY2zq8tsXY7GbXHdZLc/ZBD41mWY1K2DyvV1xcyet0OM67iZYDIQYO+QypIwb4tNkuPb5azHlgStRtXwppIwbgru6NUYJw/778+ie7qZb13PjyfGDv3quz24z3309m+C5wZe5n68ccw2+uu9KdG9WHTUqlPI49sIhPD1kDvbhnZ1w8mwOalQojbu6NUb1CqWR79Ups/ifvdGsplYqiQlxHbRToypYn3oCz+s7jYigVKy2jpdvaItXftkOY95d9nRfzN16CJNX7kPP5jXQv21tfLlqv+myr2hQxbSU/9PoHvjf0r14Z8Fu0/f956Z2eOb7zT7T29arhG0HTwEALqtbCfWqlPXY2Y1cCdfo2we0EuOG1OO46cM/TN9nNOSKuhhyRUH8t3RqgOV7junPPBt54htXxcyHtQP6q/u6onWdiuj6n8Xu10d0j8MXf6QA0Jot+rSqhfcWFnx+syajZjUr4OHezTDy60TLGKuXL4XMM9mYMao7zlzIxcLtRzDr0R64ceJKbblei61eoTTSjp9zP48pIWhRu6CgYCwE/Gtga3y6PBljBrZ2NzM90qdow3Q3v9QfL8/ejvt6xqFC6Vj0bV0TdStrJ8UyJWPc+8vb83ehRoXSWPBELzz/4xbM23oYvVrWxLLdBfeuEssyveDpAa2Qn68wfuFuCMSn0/k+/ST+7u0d8K5eW7uQm4fSsZ4nhFmP9kBq1ll37cNbrYplUKuieeEMANo1qGz6vEq5kjhxNgfdm9XA+Nva4/jZbABaM9/b83dZLs+KqyYyxU9NMJSY6P247rJaWLrb8yZrpWNjUKuStnNVr1Da7G3uJA9oVeZXh7ZFlXKl8NjUDYWOZdOL/bHn6Gks3HEE61NPoINJyfGe7nG4x2sUSqPq5TCqdzOMMlRP29SthO2HTtled8kYz1rGv//SBi//vN393CzpAVq7tSvRl4otUeidumMj8yru1S1qoFUdrbbRzaTT2+oc+9g1zT1GVZkNo728fsEB/8/+LX0W2LSGeUdc/7Z18PPonvjLBysAADW89pGaFUsj80w2YkoI4mqU9zmxejfbfHp3ZyzYdgRjZ201/zAGD/dphof7BDcaZtGTvdDE67MsfboPer+9BIB2r6jxhvZ7V5L3NvXBbmhaszyqlS+F/w7viB2HTqNdg8rYfvAUJq/chxmJaZbfh4iWMHccOoXxC30LE1ZNjN5JHtBqGR0shj2a1basvH7T5fjgtyT3c9cJWADTvjork0bEIznjjPv51/d3xa7Dpy3nH9jWun+tKJjo/bizW2O88NO2oN5jVnW966o4HMg6CwDoZJG0AqlcriTi46qhfcMq6N6sRpFGX/z8WE/kB/k7i66SY5e4qri3RxOPRG+ld8uaWJmUiZPncoLqBPbn8vqV8Eif5tiUdgJP9muJ0rExSBh7nU9CBYCBXp3SQzvUw5JdGbijS0PTUUqf3h2P3Uc8D8KbO9XHlfpJ5LG+zfHAVwn4+M5O6NOqls/7XYylwoSxnkNov7yvK5bvOebTkZr0+iBMWZuKYV0aeUyvVbEM7uzW2G+iv7Z1LSzeedTydW9rnrsWK/Ycw+B2dVG2VEGyXPRkb1QuW9JnNJgdVxmG+sbGlHBvgzb1Krk7qgVabffa8Uv9Lst4QigVE7rW5RkP2W/C+tuVjfG3Kxu7n7t+f9rsZFWnUhkcPnXe/Xz7KwPQ5kXtTi/XtK6NawytcVe3qImrW5g3R+55fVDIWwBcmOj9EBEkjr3Obxu5tzl/v9p0esNq5TD/8V5oWsihXC4lY0oU+ScSY0oIYrxSb7Djs/9vWAckHf3TXYUFgNs6N8CMxDTDekrgpRva4InpmyzK/Pa9duPlALSTL+A5asksyQNaie+nR3vghZ+2om29yujcuBpu6mhdGuvXprZpu7nLdW1q++3AtqN2pTK41aREGBtTAndfFWf5vmkju2Fr+knT1z66szPO5eQFFYNZqTRcwwXdpWHRarvjb2uPf87Y5H7d39E17hb/Hc52TB/ZDYdOnkdsCE4aZs1Py//VF0oBLcfOAwCUK1W4tFoyhCc1b0z0AVg1z1jxN2LHrBP2/4Z1wKlzOUHHZUeP5oEvqEoYex3Kloyx3ZnrMrRDwUig6etS3Y/HDrkMr83R7mlXvUIpnDCcCIrCleCD1b5hFcwe3TMkMURSt6bVTZunAK25x99InWjhSpLGWgRQUEo2q2SGooB7pY2bGj7UqykOHD9r+fqkEV3wxR8pqGhSYw9ngg4VJvoIMybMUFr/Qj+ULx34Zm5WpeHCENGuKB7RPQ5Ld2egb6ta+HFDmt/3vHt7e6RkWh9gFBnv3NYeszcdDMmyvPO3q9mqcfVy2J951k8nbfF51mKQgEt8XLWL+keMmOhDwLUjF+PFfAEFczFNUXkfqLExJdwXkwQ6iAt7sUq0e+e29pYXO10Mbu3cwLSJqTDu6NIQP286iL9eqfU/dGtaHZPv7YJT53Lwj2kb3ceNayTK37o1RmJKsd/ktsheHdoWbepF50+XMtGHgCuVucbYX2pu6FAPifuP45mBvmOKXbdvGN61kc9rThaqJOkEtSuVwcIne3tM69uqFn7ZrNUYXIm+crmS7j6QhIsw0d9l6GP5+7UtIna7ETNM9CFQooRgxb/6hrQZ5GJSpmQM3rz1CtPX6lUpW+QOzEgoU1Jrd60YZN8FBS9UTTd/v7ZFSJYTCk/2axnpEDxwLw4Rswt+6OI1+PK6eHbQOdx1VeE6gSmwxtW0EWidTG4HEqyezWtEXXKNJkz0ZEvHRlUwtEM9/COKSk3hVKKE4KHe4b0V76WuXYPKWPp0HzTyM1LNTml/2dN9CzX2/1LCRE+2lIwpgf8bFrofWCECgMbVza8ridOn22kObRTCO4c6FRM9RZ2yJWMs71VCl4bHrmmOTo2r8pfWQoSJnqLOjlcHRjoEirDYEFwBTgWi/5IuIiIqEiZ6IiKHY6InInI4JnoiIodjoicicjiOuiEqBu/d0R61K1n/hB1RODHRExUDfz94QhRubLohInI4JnoiIodjoicicjgmeiIih2OiJyJyOCZ6IiKHY6InInI4JnoiIocTpVSkY/AgIhkA9hdhETUAHAtROKHEuILDuILDuILjxLgaK6VMb+IfdYm+qEQkQSkVH+k4vDGu4DCu4DCu4FxqcbHphojI4ZjoiYgczomJ/pNIB2CBcQWHcQWHcQXnkorLcW30RETkyYkleiIiMmCiJyJyOMckehEZKCK7RCRJRMYUw/oaisjvIrJdRLaJyD/06S+JSLqIbNT/Bhve86we3y4RGRCu2EUkRUS26OtP0KdVE5GFIrJH/19Vny4iMkFf92YR6WRYzj36/HtE5J4ixtTKsE02isgpEXk8EttLRCaJyFER2WqYFrLtIyKd9e2fpL9XihDX2yKyU1/3jyJSRZ8eJyLnDNvt40Drt/qMhYwrZN+biDQRkTX69OkiUqoIcU03xJQiIhsjsL2sckPk9jGl1EX/ByAGwF4ATQGUArAJQJswr7MugE7644oAdgNoA+AlAE+ZzN9Gj6s0gCZ6vDHhiB1ACoAaXtPeAjBGfzwGwJv648EA5gEQAN0ArNGnVwOQrP+vqj+uGsLv6zCAxpHYXgB6AegEYGs4tg+Atfq8or93UBHi6g8gVn/8piGuOON8XssxXb/VZyxkXCH73gB8B2CY/vhjAA8XNi6v18cDeDEC28sqN0RsH3NKib4rgCSlVLJSKhvANABDw7lCpdQhpdR6/fFpADsA1PfzlqEApimlLiil9gFI0uMurtiHAvhSf/wlgBsN079SmtUAqohIXQADACxUSmUppY4DWAhgYIhiuRbAXqWUvyugw7a9lFLLAGSZrK/I20d/rZJSarXSjsivDMsKOi6l1AKlVK7+dDUAv79JGGD9Vp8x6Lj8COp700ui1wCYGcq49OXeDmCqv2WEaXtZ5YaI7WNOSfT1ARwwPE+D/6QbUiISB6AjgDX6pNF6FWySobpnFWM4YlcAFohIooiM1KfVVkod0h8fBlA7AnG5DIPnARjp7QWEbvvU1x+HOj4AuA9a6c2liYhsEJGlInK1IV6r9Vt9xsIKxfdWHcAJw8ksVNvragBHlFJ7DNOKfXt55YaI7WNOSfQRIyIVAHwP4HGl1CkAHwFoBqADgEPQqo/FradSqhOAQQAeFZFexhf1UkBExtXq7a83AJihT4qG7eUhktvHiog8DyAXwLf6pEMAGimlOgJ4EsAUEalkd3kh+IxR9715GQ7PwkSxby+T3FCk5RWFUxJ9OoCGhucN9GlhJSIloX2R3yqlfgAApdQRpVSeUiofwKfQqqz+Ygx57EqpdP3/UQA/6jEc0at8rurq0eKOSzcIwHql1BE9xohvL12otk86PJtXihyfiIwAcD2Av+kJAnrTSKb+OBFa+3fLAOu3+oxBC+H3lgmtqSLWJN5C0Zd1M4DphniLdXuZ5QY/ywv/PmancyHa/wDEQuuoaIKCjp62YV6nQGsbe99rel3D4yegtVcCQFt4dlIlQ+ugCmnsAMoDqGh4/Ae0tvW34dkR9Jb+eAg8O4LWqoKOoH3QOoGq6o+rhWC7TQNwb6S3F7w650K5feDbUTa4CHENBLAdQE2v+WoCiNEfN4V2oPtdv9VnLGRcIfveoNXujJ2xjxQ2LsM2Wxqp7QXr3BCxfSxsibC4/6D1XO+GdqZ+vhjW1xNa1WszgI3632AAXwPYok+f7XVAPK/HtwuGXvJQxq7vxJv0v22u5UFrC10MYA+ARYYdRgBM1Ne9BUC8YVn3QetMS4IhORchtvLQSnCVDdOKfXtBq9IfApADrX3z/lBuHwDxALbq7/kA+hXohYwrCVo7rWsf+1if9xb9+90IYD2AvwRav9VnLGRcIfve9H12rf5ZZwAoXdi49OlfABjlNW9xbi+r3BCxfYy3QCAicjintNETEZEFJnoiIodjoicicjgmeiIih2OiJyJyOCZ6IiKHY6InInK4/weSmu2tJtBRHwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 展示模型损失优化的过程\n",
    "plt.plot(torch.tensor(lossi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([32561, 2]), torch.Size([32561, 2]), torch.Size([32561]))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 关闭梯度追踪\n",
    "with torch.no_grad():\n",
    "    logits = model(x)  \n",
    "    probs = F.softmax(logits, dim=1)\n",
    "    pred = torch.where(probs[:, 1] > 0.5, 1, 0)\n",
    "logits.shape, probs.shape, pred.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.2928), tensor(0.5902), tensor(0.3914))"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torcheval.metrics.functional.classification import binary_recall\n",
    "from torcheval.metrics.functional import binary_precision, binary_f1_score\n",
    "binary_recall(pred, y), binary_precision(pred, y), binary_f1_score(pred, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 展示如何利用排序数据得到对偏好的估计\n",
    "# 此处只做模型结构展示，并不训练和使用模型\n",
    "class PreferenceModel:\n",
    "    \n",
    "    def __init__(self, pref):\n",
    "        self.pref = pref\n",
    "        \n",
    "    def __call__(self, x0, x1):\n",
    "        self.out = torch.concat((self.pref(x0), self.pref(x1)), dim=1)\n",
    "        return self.out\n",
    "    \n",
    "    def parameters(self):\n",
    "        return self.pref.parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 预测偏好的模型\n",
    "preference = Linear(5, 1)\n",
    "# 将两个数据的偏好组合在一起，以便和排序数据结合在一起\n",
    "p_model = PreferenceModel(preference)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0.0679, 1.1044]]), tensor([[0.2618, 0.7382]]))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 随机选取x0和x1\n",
    "x0 = x[[0]]\n",
    "x1 = x[[1]]\n",
    "p_logits = p_model(x0, x1)\n",
    "# 得到有偏好推导出来的排序概率，该数据可以与观测到的实际排序相结合，定义模型损失\n",
    "p_probs = F.softmax(p_logits, dim=1)\n",
    "p_logits, p_probs"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
